import sys
import numpy as np
import scipy
import sklearn
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
from utils_algo import *
from models import *
from sklearn.model_selection import train_test_split
 
# Seed NumPy and PyTorch (CPU generator plus every CUDA device) with 0.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)


def _collect_outputs(model, loader, transform, out_dim):
    """Stack ``transform(model(images))`` over every batch of ``loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built, and
    concatenates once at the end instead of growing a tensor per batch.
    Returns an empty ``(0, out_dim)`` tensor when ``loader`` yields nothing.
    """
    chunks = []
    with torch.no_grad():
        for images, _ in loader:
            chunks.append(transform(model(images.to(device))))
    if not chunks:
        return torch.empty(0, out_dim, device=device)
    return torch.cat(chunks, 0)


def confidence_generator(train_loader, SC_train_loader, ordered_class, ratio_loader, noise):
    """Train two auxiliary MLPs and derive per-sample class weights from them.

    Args:
        train_loader: yields ``(images, labels)`` used to fit a 10-way
            classifier with cross-entropy plus a small penalty on the squared
            softmax probabilities.
        SC_train_loader: yields ``(images, labels)``; only the images are
            used. Both models are evaluated on it, so it must iterate in a
            fixed order for the returned rows to line up.
        ordered_class: iterable of class indices whose confidences are summed
            to form the per-row normaliser of ``weight``.
        ratio_loader: yields ``(images, labels)`` where label ``0`` marks one
            sample set ("sc") and label ``1`` the other ("u"); used to fit a
            non-negative scalar model with a squared-loss ratio objective.
        noise: if truthy, the confidences are replaced by the one-hot vector
            of their arg-max before anything else is derived from them.

    Returns:
        ``(weight, weight2, conf)``:
        ``conf`` is the (possibly hardened) softmax output, one row per
        sample of ``SC_train_loader``; ``weight`` is ``conf`` divided row-wise
        by the sum of its ``ordered_class`` entries; ``weight2`` is ``conf``
        multiplied row-wise by the scalar model's output.
    """
    num_classes = 10
    softplus = torch.nn.functional.softplus

    # ---- Stage 1: probabilistic classifier -> confidences -----------------
    model = mlp3_model(input_dim=28*28, hidden_dim=100, output_dim=num_classes)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss()
    for _ in range(20):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=-1)
            # Cross-entropy plus 0.01 * sum of squared class probabilities.
            loss = loss_function(outputs, labels) + 0.01 * torch.sum(probs * probs)
            loss.backward()
            optimizer.step()

    conf = _collect_outputs(
        model, SC_train_loader, lambda logits: torch.softmax(logits, dim=-1), num_classes
    )

    if noise:
        # Harden each row to a one-hot vector at its arg-max.
        hard = torch.max(conf, dim=1)[1]
        conf = torch.zeros_like(conf).scatter_(1, hard.unsqueeze(1), 1.0)

    # Normalise each row by the total confidence assigned to ordered_class.
    # A row whose ordered_class mass is zero yields inf/nan, as before.
    class_idx = torch.as_tensor(list(ordered_class), dtype=torch.long, device=conf.device)
    weight = conf / conf[:, class_idx].sum(dim=1, keepdim=True)

    # ---- Stage 2: non-negative scalar ratio model -------------------------
    model2 = mlp3_model(input_dim=28 * 28, hidden_dim=100, output_dim=1)
    model2 = model2.to(device)
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4, weight_decay=1e-4)
    for _ in range(10):
        for images, labels in ratio_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer2.zero_grad()
            # Same softplus link as at inference below. The previous training
            # code used log(exp(1 + exp(x))) == 1 + exp(x), which did not
            # match the log(1 + exp(x)) applied when the model was evaluated.
            ratio = softplus(model2(images))
            is_sc = labels == 0
            is_u = labels == 1
            loss_sc = torch.sum(ratio[is_sc] ** 2) / torch.sum(is_sc)
            loss_u = -2 * torch.sum(ratio[is_u]) / torch.sum(is_u)
            loss = loss_sc + loss_u
            loss.backward()
            optimizer2.step()

    # softplus is the overflow-safe form of log(1 + exp(x)).
    coeff = _collect_outputs(model2, SC_train_loader, softplus, 1)
    weight2 = conf * coeff
    return weight, weight2, conf


def _in_classes_mask(labels, ordered_class):
    """Return a boolean mask over ``labels`` that is True where the label is in ``ordered_class``."""
    classes = torch.as_tensor(list(ordered_class), dtype=labels.dtype, device=labels.device)
    # (N, 1) == (1, K) -> (N, K); a row is selected if it matches any class.
    return (labels.unsqueeze(1) == classes.unsqueeze(0)).any(dim=1)


def prepare_mnist_data(batch_size, ordered_class, noise):
    """Build the MNIST loaders used for training with generated confidences.

    Args:
        batch_size: batch size of every returned loader and of the loaders
            handed to ``confidence_generator`` (except the ratio loader, which
            is a single full batch).
        ordered_class: iterable of digit labels; only images whose label is
            in it are kept as "sc" samples.
        noise: forwarded unchanged to ``confidence_generator``.

    Returns:
        ``(train_loader, naive_train_loader, norsc_train_loader, test_loader)``.
        The first three iterate over the same selected images, paired with
        ``weight``, ``conf`` and ``weight2`` from ``confidence_generator``
        respectively; ``test_loader`` iterates the MNIST test split with its
        original labels.
    """
    ordinary_train_dataset = dsets.MNIST(root='./data/mnist', train=True, transform=transforms.ToTensor(), download=True)
    full_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=len(ordinary_train_dataset), shuffle=True)
    # One batch holds the whole (shuffled) training set as tensors.
    image, labels = next(iter(full_loader))

    # 30% of the data is set aside to train the confidence model.
    train_data, gen_conf_data, train_label, gen_conf_label = train_test_split(image, labels, stratify=labels, test_size=0.3)
    train_dataset_for_conf_gen = torch.utils.data.TensorDataset(gen_conf_data, gen_conf_label)
    train_loader_for_conf_gen = torch.utils.data.DataLoader(dataset=train_dataset_for_conf_gen, batch_size=batch_size, shuffle=True)

    # Ratio-model data: "sc" samples restricted to ordered_class get label 0,
    # all "u" samples get label 1.
    sc_data, u_data, sc_label, u_label = train_test_split(train_data, train_label, stratify=train_label, test_size=0.1428)
    sc_data = sc_data[_in_classes_mask(sc_label, ordered_class)]
    image_ratio = torch.cat((sc_data, u_data), 0).to(device)
    label_ratio = torch.cat((torch.zeros(len(sc_data)), torch.ones(len(u_label))), 0).to(device)
    ratio_dataset = torch.utils.data.TensorDataset(image_ratio, label_ratio)
    ratio_loader = torch.utils.data.DataLoader(dataset=ratio_dataset, batch_size=len(ratio_dataset), shuffle=True)

    # NOTE(review): the selection below is taken from the full training set,
    # not from train_data, so it also covers the split used to fit the
    # confidence model. Kept as in the original; confirm this is intended.
    sc_mask = _in_classes_mask(labels, ordered_class)
    image_sc = image[sc_mask]
    label_sc = labels[sc_mask]

    # shuffle=False so the generated rows line up with image_sc.
    SC_train_dataset = torch.utils.data.TensorDataset(image_sc, label_sc)
    SC_train_loader = torch.utils.data.DataLoader(dataset=SC_train_dataset, batch_size=batch_size, shuffle=False)
    weight, weight2, conf = confidence_generator(train_loader_for_conf_gen, SC_train_loader, ordered_class, ratio_loader, noise)

    train_dataset = torch.utils.data.TensorDataset(image_sc, weight)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    naive_train_dataset = torch.utils.data.TensorDataset(image_sc, conf)
    naive_train_loader = torch.utils.data.DataLoader(dataset=naive_train_dataset, batch_size=batch_size, shuffle=True)
    norsc_train_dataset = torch.utils.data.TensorDataset(image_sc, weight2)
    norsc_train_loader = torch.utils.data.DataLoader(dataset=norsc_train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = dsets.MNIST(root='./data/mnist', train=False, transform=transforms.ToTensor(), download=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, naive_train_loader, norsc_train_loader, test_loader


